# Pundits: predict ideology (pivot-scaled ideal points), with party features,
# using hierNet (a lasso for hierarchical interactions) instead of a regular
# lasso.

library(tidyverse)
library(data.table)
library(dtplyr)
library(network)

library(hierNet)

# library() rather than require(): require() only returns FALSE when a package
# is missing, so the failure would surface later and less clearly (e.g. as
# "could not find function registerDoMC"). library() stops here instead.
library(foreach)
library(doMC)

# Fork-based parallel backend used by the %dopar% loop later in this script.
registerDoMC(cores = 11)

# Project helpers; their contents are not visible from this script.
source("pundits_functions.R")

# Global seed. run_predideo() additionally reseeds at the start of every call.
set.seed(1111)

# Fit a strong-hierarchy hierNet path predicting `outcome` from selected columns
# of `input`, cross-validate it, and save both fits to disk.
#
# Arguments:
#   input       data frame of predictors. Rows are used positionally as aligned
#               with `outcome`. NOTE(review): no key-based alignment is checked;
#               confirm the callers' row order.
#   outcome     response vector of length nrow(input). NA entries are dropped
#               together with the matching rows of the predictor matrix.
#   includevars character vector of column names of `input` to use (non-empty).
#   save_path   prefix for output files (used via paste0, so it should end in
#               "/"). The directory is created if it does not exist.
#   name        suffix used in the output file names.
#   seed        RNG seed set before fitting (affects hierNet.cv fold draws).
#   lamlist     lambda grid passed to hierNet.path. Defaults to 10, 20, ..., 200.
#
# Side effects: writes <save_path>mod_<name>.RData (object `mod`, the path fit)
# and <save_path>modcv_<name>.RData (object `mod.cv`, the CV result).
# Returns: invisibly, a named character vector of the two file paths.
run_predideo <- function(input,
                         outcome,
                         includevars,
                         save_path = "../output/",
                         name,
                         seed = 11111,
                         lamlist = seq(from = 10, to = 200, by = 10)) {
  if (!requireNamespace("hierNet", quietly = TRUE)) {
    stop("package 'hierNet' is required", call. = FALSE)
  }
  if (length(includevars) == 0) {
    stop("`includevars` is empty: no predictors to model", call. = FALSE)
  }
  set.seed(seed)

  message('shaping input')
  X <- input %>%
    dplyr::select(dplyr::all_of(includevars)) %>%
    as.matrix()

  # Standardise columns. The data.frame() round-trip is kept from the original:
  # it drops scale()'s "scaled:*" attributes and may rewrite non-syntactic
  # column names. Standardisation uses all rows, including rows whose outcome
  # is NA and is removed below.
  X_scale <- as.matrix(data.frame(scale(X)))

  Y <- outcome
  if (length(Y) != nrow(X_scale)) {
    stop("length(outcome) (", length(Y), ") does not match nrow(input) (",
         nrow(X_scale), ")", call. = FALSE)
  }

  # Compute the missingness mask once and apply it to BOTH Y and X_scale.
  # Recomputing which(is.na(Y)) after Y was filtered returns integer(0), and
  # X_scale[-integer(0), ] would silently select zero rows.
  missing_y <- is.na(Y)
  if (any(missing_y)) {
    message(paste0("removing ", sum(missing_y), " observations due to missingness"))
    Y <- Y[!missing_y]
    X_scale <- X_scale[!missing_y, , drop = FALSE]
  }

  if (nzchar(save_path) && !dir.exists(save_path)) {
    dir.create(save_path, recursive = TRUE, showWarnings = FALSE)
  }
  mod_file <- paste0(save_path, "mod_", name, ".RData")
  cv_file <- paste0(save_path, "modcv_", name, ".RData")

  message("modeling...")
  mod <- hierNet::hierNet.path(y = Y,
                               x = X_scale,
                               lamlist = lamlist,
                               strong = TRUE,
                               trace = 1)

  # Persist the path fit before cross-validation so it survives a CV failure.
  save(mod, file = mod_file)

  mod.cv <- hierNet::hierNet.cv(mod,
                                x = X_scale,
                                y = Y,
                                trace = 1)

  save(mod.cv, file = cv_file)

  invisible(c(mod = mod_file, cv = cv_file))
}


# Inputs: aggregated pundit features and masked ideal points. The first column
# of the feature file is excluded from the predictors (presumably an
# identifier; confirm). Rows of the two files are used positionally as
# aligned. NOTE(review): no join key is checked here.
pundits_agg0 <- readr::read_csv("../data/pundits_parrot_aggregated_imp0.csv")

ideal.points_d1 <- readr::read_csv("../data/ideal_points_masked.csv")

# Row indices with a missing tweetscore.
# NOTE(review): not used anywhere in this script; kept for interactive use.
which.na.ts <- which(is.na(ideal.points_d1$tweetscore))

message("running heldout loop...")

# Each topic is held out in turn: every feature column whose name contains the
# topic string is dropped, and run_predideo() is run on the remaining columns.
topics <- c('china', 'class', 'climate', 'conservative',
            'democrat',
            'far_left', 'far_right', 'gender', 'guns',
            'health_care_insurance', 'immigration',
            'iran', 'israel', 'lgbt', 'liberal',
            'mueller', 'progressive', 'race',
            'reproductive_health', 'republican',
            'taxes_spending', 'trade')

# `topic` rather than `t`, to avoid shadowing base::t() inside the worker.
foreach(topic = topics) %dopar% {
  includes <- names(pundits_agg0)[-1]

  # Use a logical mask, not includes[-which(...)]. When no column matches,
  # which() is integer(0) and x[-integer(0)] is EMPTY, which would leave the
  # model with no predictors at all. fixed = TRUE matches the topic as a
  # literal substring.
  held_out <- grepl(topic, includes, fixed = TRUE)
  if (!any(held_out)) {
    message("no feature columns matched topic '", topic, "'; nothing held out")
  }
  includes <- includes[!held_out]

  run_predideo(input = pundits_agg0,
               outcome = ideal.points_d1$ideal.point,
               includevars = includes,
               save_path = "../output/heldout_hiernets/",
               name = paste0("heldout_", topic))
}